Enable native AMD ROCm inference via Triton kernels#166

Merged
jandom merged 5 commits into aqlaboratory:main from singagan:amd-rocm-inference
Apr 9, 2026

Conversation

@singagan
Contributor

@singagan singagan commented Apr 5, 2026

Enable native AMD ROCm inference via Triton kernels

Summary

OpenFold3 now supports high-performance native inference on AMD GPUs with ROCm.

OpenFold3 previously lacked a high-performance inference path on AMD GPUs
because Evoformer attention and TriangleMultiplicativeUpdate depended on
CUDA-specific kernels with no ROCm support. This PR unlocks native AMD
inference by replacing those dependencies with Triton kernels and wiring the
new path through the full OpenFold3 model stack.

It also adds ROCm validation, kernel correctness tests, and ready-to-use AMD
installation and runtime configuration, making native AMD inference practical
out of the box.

Changes

  • add Triton Evoformer attention with i64 overflow-safe indexing for long sequences
  • add fused Triton kernels for layernorm, linear, and sigmoid-gated linear used
    in the TriangleMultiplicativeUpdate inference path
  • thread use_triton_triangle_kernels through the full model stack, matching
    the existing use_deepspeed_evo_attention/use_cueq_triangle_kernels pattern
  • add ROCm backend validation
  • add 9 kernel tests covering forward and backward correctness in bf16 and fp32
  • add examples/example_runner_yamls/triton.yml as a ready-to-use runner config
  • add optional installation support via pip install openfold3[rocm]
  • add environments/production-amd-linux-64.yml for AMD conda environments

Related Issues

  • N/A

Testing

  • validated ROCm backend execution
  • ran 9 kernel correctness tests in bf16 and fp32
  • verified AMD-specific install and runtime configuration

Adds native AMD ROCm inference support through Triton kernels,
including Evoformer attention and TriangleMultiplicativeUpdate inference
kernels, along with validation, tests, and AMD-specific install and runtime
configuration.

- add Triton Evoformer attention with i64 overflow-safe indexing for long sequences
- add fused Triton kernels for TriangleMultiplicativeUpdate inference
- thread `use_triton_triangle_kernels` through the model stack, matching the existing use_deepspeed_evo_attention pattern
- validate ROCm backend execution
- add kernel tests covering forward and backward correctness in bf16 and fp32
- add ready-to-use AMD runner and environment configs
- add optional installation support via `pip install openfold3[rocm]`
@jnwei jnwei linked an issue Apr 6, 2026 that may be closed by this pull request
Contributor

@jnwei jnwei left a comment

Thank you very much for this contribution @singagan! It is very thorough, complete with tests. I know several users are looking forward to this addition.

I have two questions and a few review requests relating to documentation.


Question 1: Installation
I understand there is a challenge with the ROCm installation because the ROCm wheels are only available for the versions of PyTorch provided through the extra index, e.g. pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2.

In other words, a single line installation for openfold3 and rocm support is not possible at this time.

Pinging @sdvillal and @Emrys-Merlin to see if they have thoughts on an elegant solution for how to handle the extra dependencies. We do not need to add this solution to this PR, but it would be good to think about what we'd like to aim for as a best outcome.

Question 2: Kernel Compilation
I see that the evoformer kernel appears to have some caching and compilation behavior (gated by tl.heuristics) to optimize the kernel sizes.

A few questions about this behavior:

  • In practice, a user of OpenFold may submit queries of many different sequences with different sequence lengths. In this case, is compilation and caching performed for each unique sequence length? If so, how much extra time do these compilation steps add to the overall workflow?
  • Is there a way to skip the compilation step altogether if the user prefers to avoid it?

I have limited experience with Triton kernel compilation, so please feel free to correct any misunderstandings or refer me to documentation about compilation.


Documentation requests:

Thank you for providing examples and comments for how to run with ROCm. To make the AMD compatibility modes more visible, it would be best to add these instructions to the main documentation.

In particular:

  • The installation documentation, with the custom PyTorch pip install command, may be added here.

  • The selection of the inference mode with the Triton kernels could be added to the inference documentation, as one of the inference modes, in this section.

pyproject.toml:

Since the option pip install openfold3[rocm] doesn't add any relevant dependencies, I think we should remove the option from pyproject.toml for now.

One reason we might keep the option is if we can add validation logic to the installation option, to check if pytorch was installed using the extra rocm indices. Perhaps @sdvillal and @Emrys-Merlin will also have thoughts here.

@Emrys-Merlin

Emrys-Merlin commented Apr 7, 2026

Thanks for pinging us.

@sdvillal and I had a look at the PR and we think that we should be able to add a ROCm-specific environment (e.g., called openfold3-rocm) to the pixi setup. Then users could choose the appropriate environment via

pixi run -e openfold3-rocm run_openfold ...
pixi run -e openfold3-cuda12 run_openfold ...
pixi run -e openfold3-cuda13 run_openfold ...
pixi run -e openfold3-cpu run_openfold ...

Most of the groundwork should already be there. It's a pity that ROCm is not yet on conda-forge, but we should be able to get everything we need from PyPI (with the ROCm-PyTorch-index).

I will try to cherry-pick this PR on top of the pixi beta branch and see if I can get a working environment. However, I'm not sure I will get around to it this week. Also, I might need help testing, as I don't have direct access to AMD accelerators. I will come back to you if that becomes a blocker.

Let me know if that goes in the direction you were thinking @jnwei.

- document ROCm install steps and Triton inference mode in Installation.md and inference.md
- add validate-openfold3-rocm console script to verify ROCm environment after install
- remove empty openfold3[rocm] pip extra in favour of plain pip install openfold3
@singagan singagan force-pushed the amd-rocm-inference branch from d06f8c5 to c7dd1cc on April 7, 2026 at 20:04
@singagan
Contributor Author

singagan commented Apr 7, 2026

Thank you for the detailed review @jnwei.

In practice, a user of OpenFold may submit queries of many different sequences with different sequence lengths. In this case, is compilation and caching performed for each unique sequence length? If so, how much extra time do these compilation steps add to the overall workflow?

Triton JIT-compiles a separate kernel variant for each unique sequence length, since SEQ_LEN is a compile-time constant (tl.constexpr) that gets specialized into the generated GPU code. For a fixed model checkpoint, the head count and embedding dimension are constant, so the main source of new compilations is the sequence length varying across queries.

Triton caches compiled kernels to disk (~/.triton/cache) and reuses them across process restarts. Compilation is therefore a one-time cost per unique sequence length per machine. A cold-cache run takes roughly a few seconds longer than a warm-cache run at the same sequence length. For any repeated workload or subsequent runs, the overhead is zero, and the kernel executes at native GPU speed.
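As an aside (this is not part of the PR, purely an illustration): a common way to bound the number of compilations is to pad every query up to a small set of length buckets, so only a handful of kernel variants are ever specialized. A hypothetical helper:

```python
def bucket_length(seq_len: int, min_bucket: int = 256) -> int:
    """Round a sequence length up to the next power-of-two bucket.

    Padding every query to its bucket length means at most O(log N)
    distinct kernel variants are ever JIT-compiled, instead of one per
    unique sequence length.
    """
    bucket = min_bucket
    while bucket < seq_len:
        bucket *= 2
    return bucket
```

For example, sequence lengths 260, 300, and 400 would all share the 512-length kernel variant, at the cost of some padded compute.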

Regarding tl.heuristics: this is unrelated to compilation overhead. It pre-computes three boolean flags at kernel launch time (EVEN_Q, EVEN_KV, EVEN_DIM) and folds them in as compile-time constants to eliminate boundary-check branches when the sequence length divides cleanly into the tile size. It is simply a runtime optimization.

Is there a way to skip the compilation step all together if the user prefers to skip these steps?

Not directly. JIT compilation is intrinsic to how Triton generates GPU-native code and cannot be bypassed. However, since the compiled kernels are cached to disk, the cost is paid once per unique sequence length per machine and never again. Users who need predictable latency from the very first query can pre-warm the cache by running a short dummy forward pass at each expected sequence length before submitting real queries. For any repeated workload, the compilation overhead is zero.
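The pre-warming idea can be sketched as follows; `run_forward` stands in for a short dummy OpenFold3 forward pass at a given length (the names here are illustrative, not OpenFold3 API):

```python
def prewarm(run_forward, expected_lengths):
    """Trigger Triton JIT compilation (and its disk caching) ahead of time
    by running one throwaway forward pass per expected sequence length."""
    warmed = []
    for n in sorted(set(expected_lengths)):
        run_forward(n)   # output discarded; the compilation side effect is the point
        warmed.append(n)
    return warmed
```

After a run like this, real queries at the pre-warmed lengths hit the disk cache and incur no compilation latency.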

Since the option pip install openfold3[rocm] doesn't add any relevant dependencies, I think we should remove the option from pyproject.toml for now.
One reason we might keep the option is if we can add validation logic to the installation option, to check if pytorch was installed using the extra rocm indices. Perhaps @sdvillal and @Emrys-Merlin will also have thoughts here.

Earlier, I was planning to keep the option as a hook for future PyTorch wheel support, but I decided to remove the empty openfold3[rocm] pip extra in favor of plain pip install openfold3. I added and documented a validate-openfold3-rocm console script to verify the ROCm environment after install.

I agree with @Emrys-Merlin that adding a pixi environment in a follow-up PR would be a cleaner solution. The environments/production-amd-linux-64.yml file added in this PR could hopefully serve as a blueprint for the openfold3-rocm pixi environment.

@Emrys-Merlin, happy to help with testing and verification whenever you need access to AMD hardware.

@jnwei jnwei added the safe-to-test label (internal-only label used to indicate PRs that are ready for automated CI testing) on Apr 8, 2026
@jnwei
Contributor

jnwei commented Apr 8, 2026

@singagan Thank you for the explanations regarding the compilation and caching of the kernels. I see that the kernel caching will speed up inference long term, so long as the Triton cache remains. I wonder if there is a recommended way of working with Triton caches for users who may have transient access to compute, for example, a user who runs primarily on AWS instances.
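One workable pattern (an editor's aside, not from this PR): Triton honors the TRITON_CACHE_DIR environment variable, so a transient instance can point the cache at storage that outlives it and sync it out before shutdown. The path and bucket name below are illustrative:

```shell
# Triton honors TRITON_CACHE_DIR (illustrative path; use a persistent volume)
export TRITON_CACHE_DIR=/tmp/triton-cache
mkdir -p "$TRITON_CACHE_DIR"

# Before the instance terminates, persist the cache (bucket name is hypothetical):
#   aws s3 sync "$TRITON_CACHE_DIR" s3://my-bucket/triton-cache
# On the next instance, restore it before running inference:
#   aws s3 sync s3://my-bucket/triton-cache "$TRITON_CACHE_DIR"
```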

@Emrys-Merlin, your suggestion of the specific pixi environment setups sounds great! I think that is a great direction to aim for; we can aim to have a follow-up PR that includes ROCm.


@singagan I have now had the chance to run the full set of unit tests on an AMD GPU (MI210), and I observe a few failures. I am not sure if you have observed these errors in your testing. I followed the instructions added to the installation documentation and verified that the validate-openfold3-rocm checks pass. Perhaps I have made a mistake somewhere.

DeepSpeed installation / model configuration issue (openfold3/tests/test_of3_model.py and openfold3/tests/test_kernels.py)

Full error message:

    RuntimeError: Unable to JIT load the evoformer_attn op due to it not being compatible due to hardware/software issue. None

Full list of failures:

  • openfold3/tests/test_kernels.py::TestKernels::test_compare_diffusion_transformer_dsk_bf16
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_diffusion_transformer_dsk_fp32
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_pairformer_dsk_bf16
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_pairformer_dsk_fp32
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_pairformer_dsk_fp32_chunk
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_template_stack_dsk_bf16
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_template_stack_dsk_fp32
  • openfold3/tests/test_kernels.py::TestKernels::test_compare_template_stack_dsk_fp32_chunk
  • openfold3/tests/test_kernels.py::TestKernels::test_dsk_backward_bf16
  • openfold3/tests/test_kernels.py::TestKernels::test_dsk_backward_fp32
  • openfold3/tests/test_kernels.py::TestKernels::test_dsk_forward_bf16
  • openfold3/tests/test_kernels.py::TestKernels::test_dsk_forward_fp32
  • openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=train-dtype=torch.float32]
  • openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=train-dtype=torch.bfloat16]
  • openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=eval-dtype=torch.float32]
  • openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_kernels[model=eval-dtype=torch.bfloat16]
  • openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_large_eval[dtype=torch.float32]
  • openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_large_eval[dtype=torch.bfloat16]
  • OpenFold may eventually remove support for DeepSpeed, which would eventually resolve most of these errors. However, the changes required to remove DeepSpeed may take some time. Until those changes are merged, I'd like to suggest a few small workarounds:
    • Kernel tests: add a decorator (similar to @compare_utils.skip_unless_triton_installed() for the DeepSpeed tests) that skips the DeepSpeed-specific tests if ROCm is detected.
    • OF3 model tests: the configuration settings for Triton may have to be configured manually, similar to the "reduce model size" settings:

          if reduce_model_size:
              # To avoid memory issues in CI
              config.architecture.pairformer.no_blocks = 4
              config.architecture.diffusion_module.diffusion_transformer.no_blocks = 4

  • Another option for configuring the Triton settings is to create a [model_preset](https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/projects/of3_all_atom/config/model_setting_presets.yml) for Triton. Basically, you may take the contents of the example `triton.yml` and add them to `model_setting_presets.yml`, something like this:

    Example preset addition:

        triton:
          settings:
            memory:
              eval:
                use_triton_triangle_kernels: true
                use_deepspeed_evo_attention: false
                use_cueq_triangle_kernels: false

Then in test_of3_model.py, the preset can be applied with

project_entry = OF3ProjectEntry()
config = project_entry.get_model_config_with_presets(["eval", "triton"])

And to run inference with the triton kernels, the triton.yml simplifies to

model_update:
  presets:
    - predict
    - triton   # defined in projects.of3_all_atom.model_setting_presets.yml
    - low_mem  # to use low memory settings

LMDB cache writing error (openfold3/tests/test_lmdb.py)

Full error:

    test_lmdb.py: lmdb.Error: The environment '/tmp/pytest-of-jwei22/pytest-1/test_lmdb_roundtrip0/test_lmdb' is already open in this process.

  • We have observed this issue on other builds (see #143), and at this time we believe this error has more to do with the Python version than with the AMD build. We will address this in a separate PR.

jnwei and others added 2 commits April 8, 2026 22:53
…tests

- add skip_if_rocm() decorator in compare_utils.py
- update skip_unless_ds4s_installed() to also skip on ROCm/HIP
- add use_triton_triangle_kernels param to run_model in test_of3_model.py
- test_shape_small_kernels, test_shape_large_eval, test_shape_large_bf16_train
  now use Triton kernels on ROCm instead of failing on DeepSpeed
@singagan
Contributor Author

singagan commented Apr 8, 2026

@jnwei Thank you for the detailed report. I have pushed updates for the DeepSpeed-related tests on ROCm.

  • Added skip_if_rocm() decorator in compare_utils.py and updated skip_unless_ds4s_installed() to also skip on ROCm/HIP, which handles all the test_kernels.py failures automatically
  • For test_of3_model.py, followed your suggestion and configured Triton settings manually in run_model, similar to the reduce_model_size pattern. On ROCm, test_shape_small_kernels, test_shape_large_eval, and test_shape_large_bf16_train now use use_triton_triangle_kernels=True instead of use_deepspeed_evo_attention=True, so the tests actually run and validate the Triton path rather than being skipped.

I went with the manual configuration approach rather than the model preset to keep the changes minimal for now.

LMDB cache writing error (test_lmdb.py): Observed this as well. Good to know this is a known issue tied to Python version rather than the AMD build. Happy to defer to the separate PR.

@jnwei
Contributor

jnwei commented Apr 9, 2026

Thank you for adding the configuration changes for test_kernels.py and test_of3_models.py.

I see that test_kernels.py on AMD now correctly skips all the tests except the triton tests.

For the test_of3_models.py unit tests, I observe the following errors on all of the tests with the recent additions. My read of the error is that the Linear layer is not configured to use the fused Triton linear modules specified in the PR description. Did you encounter this error?

FAILED openfold3/tests/test_of3_model.py::TestOF3Model::test_shape_small_fp32[model=train] - RuntimeError: CUDA error: HIPBLAS_STATUS_INTERNAL_ERROR when calling hipblasLtMatmul with transpose_mat1 1 transpose_ma...

Full stack trace:

openfold3/tests/test_of3_model.py:118: in run_model
    batch, outputs = of3(batch=batch)
                     ^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openfold3/core/runners/model_runner.py:58: in forward
    return self.model(batch)
           ^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openfold3/projects/of3_all_atom/model.py:651: in forward
    si_input, si_trunk, zij_trunk = self.run_trunk(
openfold3/projects/of3_all_atom/model.py:200: in run_trunk
    s_input, s_init, z_init = self.input_embedder(
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openfold3/core/model/feature_embedders/input_embedders.py:128: in forward
    a, _, _, _ = self.atom_attn_enc(
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openfold3/core/model/layers/sequence_local_atom_attention.py:526: in forward
    ql, cl, plm = checkpoint_section(
openfold3/core/utils/checkpointing.py:152: in checkpoint_section
    return exec(fn, args)
           ^^^^^^^^^^^^^^
openfold3/core/utils/checkpointing.py:146: in exec
    return fn(*a)
           ^^^^^^
openfold3/core/model/layers/sequence_local_atom_attention.py:428: in get_atom_reps
    cl, plm = self.ref_atom_feature_embedder(
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
openfold3/core/model/layers/sequence_local_atom_attention.py:119: in forward
    cl = self.linear_ref_pos(batch["ref_pos"])
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Linear(in_features=3, out_features=128, bias=False)
input = tensor([[[ 0.8304, -0.3758,  0.8874],
         [-0.8095, -0.6798, -0.2348],
         [-0.2665, -1.6142, -1.8564],
    ...3621, -0.0953,  0.9638],
         [-0.1135, -0.3110, -0.0841],
         [-0.3144,  1.4931,  1.6451]]], device='cuda:0')

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.dtype
        deepspeed_is_initialized = (
            deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
        )
        if self.precision is not None:
            with torch.amp.autocast("cuda", enabled=False):
                bias = (
                    self.bias.to(dtype=self.precision)
                    if self.bias is not None
                    else None
                )
                return nn.functional.linear(
                    input.to(dtype=self.precision),
                    self.weight.to(dtype=self.precision),
                    bias,
                ).to(dtype=d)

        if d is torch.bfloat16 and not deepspeed_is_initialized:
            with torch.amp.autocast("cuda", enabled=False):
                bias = self.bias.to(dtype=d) if self.bias is not None else None
                return nn.functional.linear(input, self.weight.to(dtype=d), bias)

>       return nn.functional.linear(input, self.weight, self.bias)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       RuntimeError: CUDA error: HIPBLAS_STATUS_INTERNAL_ERROR when calling hipblasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 128 n 474 k 3 lda 3 ldb 3 ldc 128 abcType 0 computeType 2 scaleType 0

openfold3/core/model/primitives/linear.py:140: RuntimeError

@singagan
Contributor Author

singagan commented Apr 9, 2026

Observed this on specific systems while using hipBLASLt. We switched to rocBLAS, which was already configured in import_utils.py for inference but was not applied during tests. I added a session-scoped autouse fixture in conftest.py to apply the same setting across the full test suite.
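For reference, a fixture along these lines would do it (a sketch, not the exact conftest.py from this PR; `TORCH_BLAS_PREFER_HIPBLASLT=0` is the switch PyTorch's ROCm backend consults to fall back from hipBLASLt to rocBLAS, and it must be set before the first matmul dispatch):

```python
import os

def prefer_rocblas(env=None):
    """Steer PyTorch's ROCm GEMMs to rocBLAS instead of hipBLASLt.

    The env parameter exists only for testability; in practice this
    mutates os.environ before torch performs its first matmul.
    """
    env = os.environ if env is None else env
    # setdefault lets a user's explicit choice win over the fixture.
    env.setdefault("TORCH_BLAS_PREFER_HIPBLASLT", "0")
    return env

# In conftest.py, wrapped as a session-scoped autouse fixture:
#
# @pytest.fixture(scope="session", autouse=True)
# def _rocblas_backend():
#     prefer_rocblas()
#     yield
```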

Contributor

@jnwei jnwei left a comment

LGTM, with the latest conftest fixture for specifying rocBLAS, I see that the of3_model tests and the kernel tests all pass.

Thank you for your contribution to the OpenFold community @singagan ! I am sure many users will find these additions useful.

@jnwei jnwei added and removed the safe-to-test label (internal-only label used to indicate PRs that are ready for automated CI testing) on Apr 9, 2026
@jandom jandom merged commit 747aefc into aqlaboratory:main Apr 9, 2026
5 checks passed
@Emrys-Merlin Emrys-Merlin mentioned this pull request Apr 10, 2026

Successfully merging this pull request may close these issues.

AMD Implementation Install Instructions
